{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "# Set the default Keras float type to float32 explicitly.\n",
    "tf.keras.backend.set_floatx('float32')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 不包含model的简单版本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gradients =  tf.Tensor([[1.]], shape=(1, 1), dtype=float32)\n",
      "x before update =  <tf.Variable 'Variable:0' shape=(1, 1) dtype=float32, numpy=array([[1.]], dtype=float32)>\n",
      "x after update =  <tf.Variable 'Variable:0' shape=(1, 1) dtype=float32, numpy=array([[0.99]], dtype=float32)>\n"
     ]
    }
   ],
   "source": [
    "def process(x):\n",
    "    \"\"\"Toy forward pass: scale the input by 10.\"\"\"\n",
    "    return 10 * x\n",
    "\n",
    "\n",
    "def lossfunction(x):\n",
    "    \"\"\"Toy loss: shift the input down by 5.\"\"\"\n",
    "    return x - 5\n",
    "\n",
    "\n",
    "x = tf.Variable([[1]], dtype=tf.float32)\n",
    "with tf.GradientTape(watch_accessed_variables=False) as tape:\n",
    "    # Automatic variable tracking is switched off for this tape, so x has to be\n",
    "    # watched explicitly (tape.watch accepts plain tensors as well as variables).\n",
    "    tape.watch(x)\n",
    "    y = process(x)\n",
    "    # The loss is built from y, not x: loss = y - 5 = 10 * x - 5,\n",
    "    # so d(loss)/dx is 10.\n",
    "    loss = lossfunction(y)\n",
    "gradients = tape.gradient(loss, x)\n",
    "print('gradients = ', gradients)\n",
    "optimizer = tf.keras.optimizers.SGD()\n",
    "print('x before update = ', x)\n",
    "optimizer.apply_gradients([(gradients, x)])\n",
    "# SGD's default learning rate is 0.01, so x becomes\n",
    "# x - lr * gradient = 1 - 0.01 * 10 = 0.9\n",
    "print('x after update = ', x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 使用keras提供的loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gradients =  tf.Tensor([[10.]], shape=(1, 1), dtype=float32)\n",
      "x before update =  <tf.Variable 'Variable:0' shape=(1, 1) dtype=float32, numpy=array([[10.]], dtype=float32)>\n",
      "x after update =  <tf.Variable 'Variable:0' shape=(1, 1) dtype=float32, numpy=array([[9.9]], dtype=float32)>\n"
     ]
    }
   ],
   "source": [
    "def process(x):\n",
    "    \"\"\"Identity forward pass: hand the input back unchanged.\"\"\"\n",
    "    return x\n",
    "\n",
    "\n",
    "x = tf.Variable([[10]], dtype=tf.float32)\n",
    "optimizer = tf.keras.optimizers.SGD()\n",
    "\n",
    "with tf.GradientTape(watch_accessed_variables=False) as tape:\n",
    "    # Variable tracking is off for this tape, so x is registered by hand.\n",
    "    tape.watch(x)\n",
    "    y = process(x)\n",
    "    # Squared error against the constant target 5, averaged over the last axis.\n",
    "    loss = tf.keras.losses.mse(y, tf.constant(5, dtype=tf.float32))\n",
    "\n",
    "# With a single element, d(loss)/dx = 2 * (x - 5) = 10 at x = 10.\n",
    "gradients = tape.gradient(loss, x)\n",
    "print('gradients = ', gradients)\n",
    "print('x before update = ', x)\n",
    "# One SGD step with the default learning rate 0.01: 10 - 0.01 * 10 = 9.9\n",
    "optimizer.apply_gradients(zip([gradients], [x]))\n",
    "print('x after update = ', x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 包含frozen模型，更新x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gradients =  tf.Tensor([[100.]], shape=(1, 1), dtype=float32)\n",
      "x before update =  <tf.Variable 'Variable:0' shape=(1, 1) dtype=float32, numpy=array([[1.]], dtype=float32)>\n",
      "x after update =  <tf.Variable 'Variable:0' shape=(1, 1) dtype=float32, numpy=array([[0.]], dtype=float32)>\n"
     ]
    }
   ],
   "source": [
    "class ExampleRandomNormal(tf.keras.initializers.Initializer):\n",
    "    \"\"\"Initializer that fills a weight with fixed, caller-supplied values.\n",
    "\n",
    "    Despite the name, nothing is random: every call yields ``weights``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, weights):\n",
    "        # weights: nested list (or array-like) with the values to initialize to.\n",
    "        self.weights = weights\n",
    "\n",
    "    def __call__(self, shape, dtype=None):\n",
    "        # Return a tensor with the dtype and shape the layer asks for, rather\n",
    "        # than the raw Python list; tf.constant raises if the values cannot\n",
    "        # fill the requested shape.\n",
    "        if dtype is None:\n",
    "            dtype = tf.keras.backend.floatx()\n",
    "        return tf.constant(self.weights, dtype=dtype, shape=shape)\n",
    "\n",
    "    def get_config(self):  # To support serialization\n",
    "        return {'weights': self.weights}\n",
    "\n",
    "\n",
    "# A single Dense unit with kernel 10 and the default zero bias, i.e. y = 10 * x.\n",
    "# input_shape excludes the batch axis, so one feature per sample is (1,).\n",
    "layer = tf.keras.layers.Dense(1, kernel_initializer=ExampleRandomNormal([[10]]), input_shape=(1,))\n",
    "model = tf.keras.Sequential([layer])\n",
    "\n",
    "x = tf.Variable([[1]], dtype=tf.float32)\n",
    "with tf.GradientTape(watch_accessed_variables=False) as tape:\n",
    "    # Variable tracking is off, so only x is watched; the layer's kernel and\n",
    "    # bias are not tracked and are left untouched by the update below.\n",
    "    tape.watch(x)\n",
    "    y = model(x)\n",
    "    # loss = (y - 5)^2 = (10 * x - 5)^2, so\n",
    "    # d(loss)/dx = 2 * (10 * x - 5) * 10 = 100 at x = 1.\n",
    "    loss = tf.keras.losses.mse(y, tf.constant(5, tf.float32))\n",
    "gradients = tape.gradient(loss, x)\n",
    "print('gradients = ', gradients)\n",
    "optimizer = tf.keras.optimizers.SGD()\n",
    "print('x before update = ', x)\n",
    "# One SGD step with the default learning rate 0.01: 1 - 0.01 * 100 = 0\n",
    "optimizer.apply_gradients([(gradients, x)])\n",
    "print('x after update = ', x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 更新模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gradients =  [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>]\n",
      "model parameters before update =  [<tf.Variable 'dense_17/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[10.]], dtype=float32)>, <tf.Variable 'dense_17/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]\n",
      "model parameters after update =  [<tf.Variable 'dense_17/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[9.99]], dtype=float32)>, <tf.Variable 'dense_17/bias:0' shape=(1,) dtype=float32, numpy=array([-0.01], dtype=float32)>]\n"
     ]
    }
   ],
   "source": [
    "class ExampleRandomNormal(tf.keras.initializers.Initializer):\n",
    "    \"\"\"Initializer that fills a weight with fixed, caller-supplied values.\n",
    "\n",
    "    Despite the name, nothing is random: every call yields ``weights``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, weights):\n",
    "        # weights: nested list (or array-like) with the values to initialize to.\n",
    "        self.weights = weights\n",
    "\n",
    "    def __call__(self, shape, dtype=None):\n",
    "        # Return a tensor with the dtype and shape the layer asks for, rather\n",
    "        # than the raw Python list; tf.constant raises if the values cannot\n",
    "        # fill the requested shape.\n",
    "        if dtype is None:\n",
    "            dtype = tf.keras.backend.floatx()\n",
    "        return tf.constant(self.weights, dtype=dtype, shape=shape)\n",
    "\n",
    "    def get_config(self):  # To support serialization\n",
    "        return {'weights': self.weights}\n",
    "\n",
    "\n",
    "# A single Dense unit with kernel 10 and the default zero bias, i.e. y = 10 * x.\n",
    "# input_shape excludes the batch axis, so one feature per sample is (1,).\n",
    "layer = tf.keras.layers.Dense(1, kernel_initializer=ExampleRandomNormal([[10]]), input_shape=(1,))\n",
    "model = tf.keras.Sequential([layer])\n",
    "\n",
    "x = tf.Variable([[1]], dtype=tf.float32)\n",
    "# Create the optimizer in this cell instead of relying on one that happens\n",
    "# to be left over from an earlier cell.\n",
    "optimizer = tf.keras.optimizers.SGD()\n",
    "\n",
    "# watch_accessed_variables=True: the model's trainable variables are tracked\n",
    "# automatically, so no explicit tape.watch() call is needed.\n",
    "with tf.GradientTape(watch_accessed_variables=True) as tape:\n",
    "    y = model(x)\n",
    "    # loss = y = kernel * x + bias, so d(loss)/d(kernel) = x = 1 and d(loss)/d(bias) = 1.\n",
    "    loss = y\n",
    "gradients = tape.gradient(loss, model.trainable_variables)\n",
    "print('gradients = ', gradients)\n",
    "print('model parameters before update = ', model.weights)\n",
    "# One SGD step with the default learning rate 0.01 on the kernel and the bias.\n",
    "optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "print('model parameters after update = ', model.weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 更新多个模型的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Model was constructed with shape Tensor(\"dense_21_input:0\", shape=(None, None, 1), dtype=float32) for input (None, None, 1), but it was re-called on a Tensor with incompatible shape (1, 1).\n",
      "gradients =  [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>]\n",
      "model1 parameters before update =  [<tf.Variable 'dense_20/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[10.]], dtype=float32)>, <tf.Variable 'dense_20/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]\n",
      "model2 parameters before update =  [<tf.Variable 'dense_21/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[5.]], dtype=float32)>, <tf.Variable 'dense_21/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]\n",
      "model parameters after update =  [<tf.Variable 'dense_20/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[9.95]], dtype=float32)>, <tf.Variable 'dense_20/bias:0' shape=(1,) dtype=float32, numpy=array([-0.05], dtype=float32)>]\n",
      "model parameters after update =  [<tf.Variable 'dense_21/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[5.]], dtype=float32)>, <tf.Variable 'dense_21/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]\n",
      "gradients =  [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[10.]], dtype=float32)>, <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>]\n",
      "model1 parameters before update =  [<tf.Variable 'dense_20/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[9.95]], dtype=float32)>, <tf.Variable 'dense_20/bias:0' shape=(1,) dtype=float32, numpy=array([-0.05], dtype=float32)>]\n",
      "model2 parameters before update =  [<tf.Variable 'dense_21/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[5.]], dtype=float32)>, <tf.Variable 'dense_21/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]\n",
      "model parameters after update =  [<tf.Variable 'dense_20/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[9.849999]], dtype=float32)>, <tf.Variable 'dense_20/bias:0' shape=(1,) dtype=float32, numpy=array([-0.05999999], dtype=float32)>]\n",
      "model parameters after update =  [<tf.Variable 'dense_21/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[5.]], dtype=float32)>, <tf.Variable 'dense_21/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]\n"
     ]
    }
   ],
   "source": [
    "class ExampleRandomNormal(tf.keras.initializers.Initializer):\n",
    "    \"\"\"Initializer that fills a weight with fixed, caller-supplied values.\n",
    "\n",
    "    Despite the name, nothing is random: every call yields ``weights``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, weights):\n",
    "        # weights: nested list (or array-like) with the values to initialize to.\n",
    "        self.weights = weights\n",
    "\n",
    "    def __call__(self, shape, dtype=None):\n",
    "        # Return a tensor with the dtype and shape the layer asks for, rather\n",
    "        # than the raw Python list; tf.constant raises if the values cannot\n",
    "        # fill the requested shape.\n",
    "        if dtype is None:\n",
    "            dtype = tf.keras.backend.floatx()\n",
    "        return tf.constant(self.weights, dtype=dtype, shape=shape)\n",
    "\n",
    "    def get_config(self):  # To support serialization\n",
    "        return {'weights': self.weights}\n",
    "\n",
    "\n",
    "# Two stacked single-unit models: model1 computes 10 * x, model2 multiplies its input by 5.\n",
    "# input_shape excludes the batch axis, so one feature per sample is (1,).\n",
    "layer1 = tf.keras.layers.Dense(1, kernel_initializer=ExampleRandomNormal([[10]]), input_shape=(1,))\n",
    "model1 = tf.keras.Sequential([layer1])\n",
    "layer2 = tf.keras.layers.Dense(1, kernel_initializer=ExampleRandomNormal([[5]]), input_shape=(1,))\n",
    "model2 = tf.keras.Sequential([layer2])\n",
    "\n",
    "\n",
    "def process(x):\n",
    "    \"\"\"Chain the two models: model2(model1(x)).\"\"\"\n",
    "    return model2(model1(x))\n",
    "\n",
    "\n",
    "x = tf.Variable([[1]], dtype=tf.float32)\n",
    "# Create the optimizers in this cell instead of relying on one left over from\n",
    "# an earlier cell. One optimizer per model is used because some Keras versions\n",
    "# tie an optimizer to the first set of variables it updates (TODO confirm for\n",
    "# the installed version).\n",
    "optimizer1 = tf.keras.optimizers.SGD()\n",
    "optimizer2 = tf.keras.optimizers.SGD()\n",
    "\n",
    "# persistent=True allows tape.gradient() to be called more than once on this tape.\n",
    "with tf.GradientTape(watch_accessed_variables=True, persistent=True) as tape:\n",
    "    y = process(x)\n",
    "    loss = y\n",
    "\n",
    "# Gradients for the first model: loss = k2 * (k1 * x + b1) + b2,\n",
    "# so d(loss)/d(k1) = k2 * x = 5 and d(loss)/d(b1) = k2 = 5.\n",
    "gradients = tape.gradient(loss, model1.trainable_variables)\n",
    "print('gradients = ', gradients)\n",
    "print('model1 parameters before update = ', model1.weights)\n",
    "print('model2 parameters before update = ', model2.weights)\n",
    "# Update the first model's parameters.\n",
    "optimizer1.apply_gradients(zip(gradients, model1.trainable_variables))\n",
    "print('model parameters after update = ', model1.weights)\n",
    "print('model parameters after update = ', model2.weights)\n",
    "\n",
    "# Gradients for the second model, taken from the same recorded forward pass\n",
    "# (i.e. with model1's pre-update weights): d(loss)/d(k2) = k1 * x + b1 = 10\n",
    "# and d(loss)/d(b2) = 1.\n",
    "gradients = tape.gradient(loss, model2.trainable_variables)\n",
    "# Drop the reference to the persistent tape so its resources can be released.\n",
    "del tape\n",
    "print('gradients = ', gradients)\n",
    "print('model1 parameters before update = ', model1.weights)\n",
    "print('model2 parameters before update = ', model2.weights)\n",
    "# Update the second model's parameters; these gradients belong to model2's\n",
    "# variables, not model1's.\n",
    "optimizer2.apply_gradients(zip(gradients, model2.trainable_variables))\n",
    "print('model parameters after update = ', model1.weights)\n",
    "print('model parameters after update = ', model2.weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 重要！！！\n",
    "\n",
    "process函数和loss函数中必须全部使用tf函数，不能使用相同作用的np函数，否则无法计算梯度。  \n",
    "普通的运算符可以直接使用，不必写成tf函数的形式。  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
